{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "09abde76",
   "metadata": {},
   "source": [
    "## 8.5 ResNet概述"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7af9412",
   "metadata": {},
   "source": [
    "论文可见于：[论文链接](https://arxiv.org/abs/1512.03385)\n",
    "\n",
    "ResNet是由微软研究院的 Kaiming He, Xiangyu Zhang, Shaoqing Ren 和 Jian Sun 在 2015 年提出的深度神经网络架构。ResNet斩获当年ImageNet竞赛中分类任务第一名，目标检测第一名；COCO数据集中目标检测第一名，图像分割第一名。它的名字来源于残差网络（Residual Network），也就是 ResNet 架构中的神经网络块是建立在残差模块之上的。\n",
    "\n",
    "ResNet 是目前依然非常流行的深度神经网络架构，在计算机视觉领域有着广泛的应用，并为之后的深度神经网络架构的提出和改进奠定了基础。截止目前为止，论文引用次数已经超过14万。\n",
    "\n",
    "ResNet 是为了解决深度神经网络训练时网络退化问题而提出的，所谓的网络退化其实就是深层网络的效果反而比浅层网络更差。ResNet 的提出让我们可以训练出更深的网络，其深度突破了100层，甚至超过了1000层。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91d668bb",
   "metadata": {},
   "source": [
    "### 8.5.1 ResNet基本思想"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "010d72b4",
   "metadata": {},
   "source": [
    "ResNet 的基本思想是引入残差模块来解决深度神经网络训练时的网络退化问题。\n",
    "\n",
    "残差模块的思想是将多层神经网络块看成一个整体，然后学习残差，即将输出与输入的差作为模型的预测目标。这样做的好处是可以使梯度流动更加稳定，从而使深度神经网络得以顺利训练。\n",
    "\n",
    "假设将堆叠的几层神经元称为一个block，那么对于某个block，假设其拟合的函数是 $F(x)$，如果期望潜在映射为 $H(x)$，与其让 $F(x)$ 直接拟合 $H(x)$ ，不如去学习残差 $H(x)-x$，即 $F(x)=H(x)-x$，这样其前向路径就变成了 $F(x)+x$，即 $H(x)=F(x)+x$。其结构如下图所示。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "541f9102",
   "metadata": {},
   "source": [
    "<img src=\"./images/8-5-1.png\" width=\"50%\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c924440",
   "metadata": {},
   "source": [
    "### 8.5.2 ResNet结构"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5fca5b4",
   "metadata": {},
   "source": [
    "ResNet34的结构与VGG19的对比如下图所示：\n",
    "\n",
    "<img src=\"./images/8-5-2.png\" width=\"100%\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "058fb3bc",
   "metadata": {},
   "source": [
    "其中实现实曲线表示残差连接，虚线表示包含升维。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bae443ba",
   "metadata": {},
   "source": [
    "不同深度的ResNet结构如下图所示：\n",
    "\n",
    "<img src=\"./images/8-5-3.png\" width=\"80%\"></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "167309b3",
   "metadata": {},
   "source": [
    "看上面的结构图好像很复杂，但实际上对照上表就是三部分：\n",
    "\n",
    "1. 第一部分卷积层+最大池化层；\n",
    "2. 第二部分是4组不同数量和规格的残差模块；\n",
    "3. 第三部分平均池化层+全连接层。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bade6fb5",
   "metadata": {},
   "source": [
    "### 8.5.3 ResNet代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf3f3bbf",
   "metadata": {},
   "source": [
    "接下来看一下如何用代码实现相关网络结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d17e5ec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入必要的库\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "967af69c",
   "metadata": {},
   "source": [
    "首先定义两种残差模块。\n",
    "\n",
    "<img src=\"./images/8-5-4.png\" width=\"60%\"></img>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f3017c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicBlock(nn.Module):\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        identity = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "73730002",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Bottleneck(nn.Module):\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(planes * 4)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        identity = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv3(out)\n",
    "        out = self.bn3(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d781d405",
   "metadata": {},
   "source": [
    "然后组装ResNet的网络结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e51e4d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet(nn.Module):\n",
    "\n",
    "    def __init__(self, block, layers, num_classes=1000):\n",
    "        self.inplanes = 64\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0])\n",
    "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "    # 生成网络结构的函数，根据传入的配置，拼接出对应的网络结构\n",
    "    def _make_layer(self, block, planes, blocks, stride=1):\n",
    "        downsample = None\n",
    "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
    "            downsample = nn.Sequential(\n",
    "                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(planes * block.expansion),\n",
    "            )\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
    "        self.inplanes = planes * block.expansion\n",
    "        for i in range(1, blocks):\n",
    "            layers.append(block(self.inplanes, planes))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.layer4(x)\n",
    "\n",
    "        x = self.avgpool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc(x)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "915d0da2",
   "metadata": {},
   "source": [
    "最后和VGGNet一样，封装对应模型的函数，直接调用即可。对应layers参数的组合与上文中表格是一致的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6eb0680e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 封装函数\n",
    "def resnet18():\n",
    "    return ResNet(BasicBlock, [2, 2, 2, 2])\n",
    "\n",
    "def resnet34():\n",
    "    return ResNet(BasicBlock, [3, 4, 6, 3])\n",
    "\n",
    "def resnet50():\n",
    "    return ResNet(Bottleneck, [3, 4, 6, 3])\n",
    "\n",
    "def resnet101():\n",
    "    return ResNet(Bottleneck, [3, 4, 23, 3])\n",
    "\n",
    "def resnet152():\n",
    "    return ResNet(Bottleneck, [3, 8, 36, 3])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5190c45e",
   "metadata": {},
   "source": [
    "### 8.5.4 小结"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a45b44e1",
   "metadata": {},
   "source": [
    "* 针对网络退化问题，采用残差结构解决（隔层相连，弱化每层之间的强联系），利用残差结构让网络能够更深、收敛速度更快、优化更容易，同时参数相对之前的模型更少、复杂度更低。\n",
    "* 对于很深的网络，ResNet使用了更高效的“bottleneck”结构极大程度上降低了参数计算量。\n",
    "* ResNet深度突破了100层，甚至超过了1000层，对后来的深层神经⽹络设计产⽣了深远影响。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
